{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# LightlyTrain - Quick Start\n",
    "\n",
    "This notebook demonstrates how to use LightlyTrain for self-supervised learning on image data. It follows the [quick start guide](https://docs.lightly.ai/train/stable/quick_start.html) from the documentation.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lightly-ai/lightly-train/blob/main/examples/notebooks/quick_start.ipynb)\n",
    "\n",
    "> **Important**: When running on Google Colab, select a GPU runtime for faster processing: go to `Runtime` > `Change runtime type` and choose a GPU hardware accelerator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly-train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Important**: LightlyTrain is officially supported on:\n",
    "> - Linux: CPU or CUDA\n",
    "> - macOS: CPU only\n",
    "> - Windows (experimental): CPU or CUDA\n",
    "> \n",
    "> MPS support for macOS is planned.\n",
    "> \n",
    "> See the [installation instructions](https://docs.lightly.ai/train/stable/installation.html) for more details."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare Data\n",
    "\n",
    "You can use any image dataset for training. No labels are required, and the dataset can be structured in any way, including subdirectories. If you don't have a dataset at hand, you can download a sample clothing dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/lightly-ai/dataset_clothing_images.git my_data_dir\n",
    "!rm -rf my_data_dir/.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "See the [data guide](https://docs.lightly.ai/train/stable/train.html#train-data) for more information on supported data formats."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train\n",
    "\n",
    "Once the data is ready, you can train the model like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightly_train\n",
    "\n",
    "# Train the model\n",
    "lightly_train.train(\n",
    "    out=\"out/my_experiment\",  # Output directory\n",
    "    data=\"my_data_dir\",  # Directory with images\n",
    "    model=\"torchvision/resnet18\",  # Model to train\n",
    "    epochs=10,  # Number of epochs to train\n",
    "    batch_size=32,  # Batch size\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Note**: This is a minimal example for illustration purposes. In practice, you would want to use a larger dataset (>=10,000 images), more epochs (>=100), and a larger batch size (>=128).\n",
    "\n",
    "> **Tip**: LightlyTrain supports many [popular models](https://docs.lightly.ai/train/stable/models/index.html) out of the box.\n",
    "\n",
    "This pretrains a Torchvision ResNet-18 model using images from `my_data_dir`.\n",
    "All training logs, model exports, and checkpoints are saved to the output directory\n",
    "at `out/my_experiment`.\n",
    "\n",
    "Once the training is complete, the `out/my_experiment` directory contains the\n",
    "following files:\n",
    "\n",
    "```text\n",
    "out/my_experiment\n",
    "├── checkpoints\n",
    "│   ├── epoch=99-step=123.ckpt          # Intermediate checkpoint\n",
    "│   └── last.ckpt                       # Last checkpoint\n",
    "├── events.out.tfevents.123.0           # TensorBoard logs\n",
    "├── exported_models\n",
    "│   └── exported_last.pt                # Final exported model\n",
    "├── metrics.jsonl                       # Training metrics\n",
    "└── train.log                           # Training logs\n",
    "```"
   ]
  },
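  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check after training, the metrics file listed above can be inspected directly. The cell below is a minimal sketch, not an official LightlyTrain utility: it assumes that `metrics.jsonl` follows the JSON Lines convention suggested by its extension (one JSON object per line). The keys inside each record are not specified in this guide, so the code only prints whatever it finds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "# Path taken from the directory listing above\n",
    "metrics_path = Path(\"out/my_experiment/metrics.jsonl\")\n",
    "\n",
    "if metrics_path.exists():\n",
    "    # Assumption: one JSON object per non-empty line (JSON Lines)\n",
    "    with metrics_path.open() as f:\n",
    "        records = [json.loads(line) for line in f if line.strip()]\n",
    "    print(f\"Loaded {len(records)} metric records\")\n",
    "    if records:\n",
    "        print(\"Last record:\", records[-1])\n",
    "else:\n",
    "    print(f\"{metrics_path} not found; run the training cell first\")"
   ]
  },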
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "The final model is exported to `out/my_experiment/exported_models/exported_last.pt` in\n",
    "the default format of the library the model comes from. It can be used directly for\n",
    "fine-tuning. See [export format](https://docs.lightly.ai/train/stable/export.html#format) for more information on exporting\n",
    "models to other formats or exporting intermediate checkpoints.\n",
    "\n",
    "While the trained model has already learned good representations of the images, it\n",
    "cannot yet make any predictions for tasks such as classification, detection, or\n",
    "segmentation. To solve these tasks, the model needs to be fine-tuned on a labeled\n",
    "dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fine-Tune\n",
    "\n",
    "Now the model is ready for fine-tuning! You can use your favorite library for this step. Below is a simple example using PyTorch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.transforms.v2 as v2\n",
    "import tqdm\n",
    "from torch import nn, optim\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, models\n",
    "\n",
    "# Use the v2 Compose so that all transforms come from the same API\n",
    "transform = v2.Compose(\n",
    "    [\n",
    "        v2.Resize((224, 224)),\n",
    "        v2.ToImage(),\n",
    "        v2.ToDtype(torch.float32, scale=True),\n",
    "    ]\n",
    ")\n",
    "dataset = datasets.ImageFolder(root=\"my_data_dir\", transform=transform)\n",
    "dataloader = DataLoader(dataset, batch_size=16, shuffle=True, drop_last=True)\n",
    "\n",
    "# Load the exported model\n",
    "model = models.resnet18()\n",
    "model.load_state_dict(\n",
    "    torch.load(\"out/my_experiment/exported_models/exported_last.pt\", weights_only=True)\n",
    ")\n",
    "\n",
    "# Update the classification head with the correct number of classes\n",
    "model.fc = nn.Linear(model.fc.in_features, len(dataset.classes))\n",
    "\n",
    "device = (\n",
    "    \"cuda\"\n",
    "    if torch.cuda.is_available()\n",
    "    else \"mps\"\n",
    "    if torch.backends.mps.is_available()\n",
    "    else \"cpu\"\n",
    ")\n",
    "model.to(device)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
    "print(\"Starting fine-tuning...\")\n",
    "num_epochs = 10\n",
    "for epoch in range(num_epochs):\n",
    "    running_loss = 0.0\n",
    "    progress_bar = tqdm.tqdm(dataloader, desc=f\"Epoch {epoch + 1}/{num_epochs}\")\n",
    "    for inputs, labels in progress_bar:\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs.to(device))\n",
    "        loss = criterion(outputs, labels.to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        progress_bar.set_postfix(loss=f\"{loss.item():.4f}\")\n",
    "    # Report the mean loss over the epoch rather than only the last batch\n",
    "    epoch_loss = running_loss / len(dataloader)\n",
    "    print(f\"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Congratulations! You've just trained and fine-tuned a model using LightlyTrain!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embed\n",
    "\n",
    "Instead of fine-tuning the model, you can also use it to generate image embeddings. This is useful for clustering, retrieval, or visualization tasks. The `embed` command generates embeddings for all images in a directory:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightly_train\n",
    "\n",
    "lightly_train.embed(\n",
    "    out=\"my_embeddings.pth\",  # Exported embeddings\n",
    "    checkpoint=\"out/my_experiment/checkpoints/last.ckpt\",  # LightlyTrain checkpoint\n",
    "    data=\"my_data_dir\",  # Directory with images\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embeddings are saved to `my_embeddings.pth`. Let's load them and take a look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "embeddings = torch.load(\"my_embeddings.pth\")\n",
    "print(\"First five filenames:\")\n",
    "print(embeddings[\"filenames\"][:5])\n",
    "print(\"\\nEmbedding tensor shape:\")\n",
    "# Tensor with shape (num_images, embedding_dim)\n",
    "print(embeddings[\"embeddings\"].shape)"
   ]
  }
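  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The retrieval use case mentioned above can be sketched with plain PyTorch. This is a minimal illustration and not part of the LightlyTrain API: it assumes the file layout shown in the previous cell (a `filenames` list and an `embeddings` tensor of shape `(num_images, embedding_dim)`). The query index and the number of neighbors are arbitrary choices made for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "embeddings = torch.load(\"my_embeddings.pth\")\n",
    "filenames = embeddings[\"filenames\"]\n",
    "\n",
    "# L2-normalize each row so that a dot product equals cosine similarity\n",
    "features = F.normalize(embeddings[\"embeddings\"].float(), dim=1)\n",
    "\n",
    "query_index = 0  # Arbitrary choice: use the first image as the query\n",
    "num_neighbors = 5  # Arbitrary choice: number of similar images to list\n",
    "\n",
    "# Cosine similarity between the query and every image, shape (num_images,)\n",
    "similarities = features @ features[query_index]\n",
    "\n",
    "# Request one extra result because the query matches itself\n",
    "k = min(num_neighbors + 1, len(filenames))\n",
    "scores, indices = similarities.topk(k)\n",
    "neighbors = [\n",
    "    (index, score)\n",
    "    for index, score in zip(indices.tolist(), scores.tolist())\n",
    "    if index != query_index\n",
    "][:num_neighbors]\n",
    "\n",
    "print(f\"Query image: {filenames[query_index]}\")\n",
    "for index, score in neighbors:\n",
    "    print(f\"  {filenames[index]}  (cosine similarity: {score:.3f})\")"
   ]
  }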
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
